#!/usr/bin/env python

# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.

"""
Red Light Violation
"""

from six.moves.queue import Queue # pylint: disable=relative-import, bad-option-value

import math # pylint: disable=wrong-import-order
import py_trees
import carla
import time
import numpy as np

from srunner.scenariomanager.carla_data_provider import CarlaDataProvider
from srunner.scenariomanager.scenarioatomics.atomic_behaviors import (ActorTransformSetter,
                                                                      ActorDestroy,
                                                                      ActorSink,
                                                                      SyncArrival,
                                                                      KeepVelocity,
                                                                      StopVehicle,
                                                                      WaypointFollower)
from srunner.scenariomanager.scenarioatomics.atomic_criteria import CollisionTest
from srunner.scenariomanager.scenarioatomics.atomic_trigger_conditions import DriveDistance, InTriggerRegion
from srunner.scenarios.basic_scenario import BasicScenario
from srunner.tools.scenario_helper import get_waypoint_in_distance

from envs.scenarios.multiagent_basic_scenario import MultiAgentBasicScenario
from envs.utils.parse_scenario_config import convert_dict_to_transform

class RedLightViolation(MultiAgentBasicScenario):
    """
    Red Light Violation scenario.

    The traffic lights of the junction ahead of the ego vehicle are frozen
    with the ego/reference/opposite lights green and the left/right lights
    red. When ``config.accident_prone`` is set, a collider vehicle is spawned
    from ``config.other_actors[0]`` and timed to reach the crossing point of
    its path and the ego's path together with the ego vehicle (presumably
    running the red light). Trucks are placed in the lane to the left of the
    trigger point, up to the junction, to occlude the collider.
    """
    def __init__(self, world, ego_vehicles, config, randomize=False, debug_mode=False, criteria_enable=True, timeout=120):
        """
        Store scenario parameters and freeze the junction's traffic lights.
        Then let the base class spawn the actors and build the behavior tree.

        Args:
            world: CARLA world the scenario runs in.
            ego_vehicles: list of ego vehicles; only ``ego_vehicles[0]`` is used here.
            config: scenario configuration (trigger points, other actors,
                ``accident_prone`` flag).
            randomize: accepted for interface compatibility; unused here.
            debug_mode: forwarded to the base class.
            criteria_enable: forwarded to the base class.
            timeout: scenario timeout, stored as ``self.timeout``.
        """
        self.timeout = timeout
        self._world = world
        self._map = CarlaDataProvider.get_map()

        self._occluder = None
        self._collider = None
        # Distance for a collider pass-through condition; currently unused by this class.
        self._collider_distance = 70

        self._trigger_point_transform = config.trigger_points[0]
        self._other_actor_transform = None
        # Trucks spawned directly through the world in _initialize_actors().
        # They are not in self.other_actors, so __del__ destroys them explicitly.
        self._background_actors = []

        # ego vehicle parameters
        self._ego_vehicle_driven_distance = 30  # was 105

        # other vehicle
        self._other_actor_max_brake = 1.0  # currently unused by this class
        # 29-31 km/h converted to m/s (was 15; 30 makes the collision happen).
        self._other_actor_target_velocity = np.random.uniform(low=29, high=31) / 3.6
        print("collider target_velocity", self._other_actor_target_velocity)
        # Half-size (m) of the square trigger region around the collision point (was 3 m).
        self._region_threshold = 25

        self._traffic_light = CarlaDataProvider.get_next_traffic_light(ego_vehicles[0], False)
        self._annotations = CarlaDataProvider.annotate_trafficlight_in_group(self._traffic_light)
        RED = carla.TrafficLightState.Red
        GREEN = carla.TrafficLightState.Green
        # Freeze the light group: ego/ref/opposite green, left/right red.
        CarlaDataProvider.update_light_states(
                self._traffic_light,
                self._annotations,
                {'ego': GREEN, 'ref': GREEN, 'left': RED, 'right': RED, 'opposite': GREEN},
                freeze=True,
                )

        super(RedLightViolation, self).__init__(
            "RedLightViolation",
            ego_vehicles,
            config,
            world,
            debug_mode,
            criteria_enable=criteria_enable)

    def _initialize_actors(self, config):
        """
        Spawn the scenario actors.

        * If ``config.accident_prone`` is set:
          - the collider ('vehicle.audi.tt') is spawned, with physics
            disabled, at the lane waypoint nearest to
            ``config.other_actors[0].start``;
          - the location of ``config.other_actors[0].target`` is stored as
            the sink location used by the behavior tree.
        * Starting 3 m ahead in the lane left of the trigger point, a truck is
          try-spawned every 6 m until the lane reaches a junction waypoint.
          The occluder truck (added to ``self.other_actors``) is spawned at
          that junction waypoint.
        """
        # set up collider
        if config.accident_prone:
            other_actor_transform = convert_dict_to_transform(config.other_actors[0].start)
            other_actor_wp = self._map.get_waypoint(other_actor_transform.location)

            self._collider_sink_location = convert_dict_to_transform(config.other_actors[0].target).location

            # Snap the collider onto the lane centre.
            other_actor_transform = other_actor_wp.transform
            self._other_actor_transform = other_actor_transform
            first_vehicle_transform = carla.Transform(
                    carla.Location(other_actor_transform.location.x,
                                   other_actor_transform.location.y,
                                   other_actor_transform.location.z),
                    other_actor_transform.rotation)

            first_vehicle = CarlaDataProvider.request_new_actor('vehicle.audi.tt',
                                                                first_vehicle_transform,
                                                                rolename='scenario')
            first_vehicle.set_simulate_physics(enabled=False)
            self._collider = first_vehicle
            self.other_actors.append(first_vehicle)

        # set up occluder: walk the lane left of the trigger point towards the junction
        dummy_truck_waypoint = CarlaDataProvider.get_map().get_waypoint(
            self._trigger_point_transform.location).get_left_lane()
        dummy_truck_waypoint = dummy_truck_waypoint.next(3)[0]

        world = self.ego_vehicles[0].get_world()
        while not dummy_truck_waypoint.is_intersection:
            print("Trying to spawn actor at {}".format(dummy_truck_waypoint.transform.location))
            dummy_truck_transform = carla.Transform(
                carla.Location(dummy_truck_waypoint.transform.location.x, dummy_truck_waypoint.transform.location.y, 1),
                dummy_truck_waypoint.transform.rotation)
            dummy_truck_blueprint = CarlaDataProvider.create_blueprint('vehicle.carlamotors.carlacola',
                                                                        'scenario_background')
            # try_spawn_actor returns None when spawning fails; keep successful
            # spawns so they can be destroyed when the scenario is deleted.
            truck = world.try_spawn_actor(dummy_truck_blueprint, dummy_truck_transform)
            if truck is not None:
                self._background_actors.append(truck)
            dummy_truck_waypoint = dummy_truck_waypoint.next(6)[0]

        self._blocking_truck_transform = carla.Transform(
            carla.Location(dummy_truck_waypoint.transform.location.x,
                           dummy_truck_waypoint.transform.location.y,
                           1),
            carla.Rotation(dummy_truck_waypoint.transform.rotation.pitch,
                           dummy_truck_waypoint.transform.rotation.yaw,
                           dummy_truck_waypoint.transform.rotation.roll))
        print("occluder at:{}".format(self._blocking_truck_transform.location))
        self._occluder = CarlaDataProvider.request_new_actor('vehicle.carlamotors.carlacola',
                                                             self._blocking_truck_transform)
        self._occluder.set_simulate_physics(True)
        print("other_actors", self.other_actors)
        self.other_actors.append(self._occluder)
        print("other_actors after", self.other_actors)

    def _create_behavior(self):
        """
        Build the scenario behavior tree.

        Without a collider, the tree is a single ``DriveDistance``. It
        succeeds once the ego vehicle has driven
        ``self._ego_vehicle_driven_distance``.

        With a collider, the sequence is:
          1. ``ActorTransformSetter`` places the collider at its start transform.
          2. A parallel (ends when any child succeeds) runs two children:
             - ``SyncArrival``, which presumably times the collider to reach
               the computed collision point together with the ego vehicle;
             - a trigger that fires when the collider enters the square of
               half-size ``self._region_threshold`` around that point.
          3. A parallel (ends when any child succeeds) runs three children:
             - ``KeepVelocity``, holding the collider at
               ``self._other_actor_target_velocity``;
             - the ego driving the target distance;
             - an ``ActorSink`` at the collider's target location.
          4. The ego vehicle drives the target distance once more.

        Scenario duration is bounded separately by ``self.timeout``.

        Raises:
            RuntimeError: if the ego and collider paths have no intersection point.
        """
        scenario_sequence = py_trees.composites.Sequence("AutoCastIntersectionRedLightViolation")

        if self._collider is not None:
            location_of_collision = get_geometric_linear_intersection_by_loc_and_intersection(
                self._trigger_point_transform.location,
                self._other_actor_transform.location)

            if location_of_collision is None:
                print(self._trigger_point_transform.location)
                print(self._other_actor_transform.location)
                raise RuntimeError("Intersecting point doesn't exist")

            sync_arrival = SyncArrival(
                self._collider, self.ego_vehicles[0],
                location_of_collision)

            collider_in_intersection = InTriggerRegion(
                self._collider,
                location_of_collision.x - self._region_threshold, location_of_collision.x + self._region_threshold,
                location_of_collision.y - self._region_threshold, location_of_collision.y + self._region_threshold)

            keep_velocity_other = KeepVelocity(
                self._collider,
                self._other_actor_target_velocity)

            collider_sink = ActorSink(
                self._collider_sink_location,
                threshold=1.5
                )

            sync_arrival_parallel = py_trees.composites.Parallel(
                policy=py_trees.common.ParallelPolicy.SUCCESS_ON_ONE)
            sync_arrival_parallel.add_child(sync_arrival)
            sync_arrival_parallel.add_child(collider_in_intersection)

            keep_velocity_other_parallel = py_trees.composites.Parallel(
                policy=py_trees.common.ParallelPolicy.SUCCESS_ON_ONE)
            keep_velocity_other_parallel.add_child(keep_velocity_other)
            keep_velocity_other_parallel.add_child(
                DriveDistance(self.ego_vehicles[0], self._ego_vehicle_driven_distance))
            keep_velocity_other_parallel.add_child(collider_sink)

            scenario_sequence.add_child(ActorTransformSetter(self._collider, self._other_actor_transform))
            scenario_sequence.add_child(sync_arrival_parallel)
            scenario_sequence.add_child(keep_velocity_other_parallel)

        # A dedicated instance: a py_trees behaviour must belong to exactly one
        # parent, so the DriveDistance inside the parallel above is not reused.
        ego_drive_distance = DriveDistance(self.ego_vehicles[0], self._ego_vehicle_driven_distance)
        scenario_sequence.add_child(ego_drive_distance)
        return scenario_sequence

    def _setup_scenario_trigger(self, config):
        """
        Return no start trigger for this scenario (overrides the base-class hook).
        """
        return None

    def _create_test_criteria(self):
        """
        Create the test criteria that are later run in parallel with the behavior tree.

        Every actor in ``self.other_actors`` (the collider, if any, and the
        occluder) gets a ``CollisionTest`` with ``terminate_on_failure=True``.
        Ego-vehicle criteria are currently disabled.

        Returns:
            list: the criteria, in ``self.other_actors`` order.
        """
        return [CollisionTest(vehicle, name="CollisionTest", terminate_on_failure=True)
                for vehicle in self.other_actors]

    def __del__(self):
        """
        Remove all actors upon deletion.

        This also destroys the background trucks that were spawned directly
        through the world and are therefore not in ``self.other_actors``.
        """
        self.remove_all_actors()
        for actor in getattr(self, '_background_actors', []):
            try:
                actor.destroy()
            except RuntimeError:
                # Best-effort cleanup: the actor or the simulator connection may already be gone.
                pass
        self._background_actors = []

def _next_junction_waypoint(start_waypoint, step=2):
    """
    Follow the lane from *start_waypoint* and return the first waypoint inside a junction.

    The lane is followed in *step*-metre hops, and the start waypoint itself is
    not checked. Returns None if the lane ends (``next()`` yields no successor)
    before a junction is reached.
    """
    # NOTE(review): this still loops forever on a lane that never reaches a junction.
    waypoint = start_waypoint
    while True:
        successors = waypoint.next(step)
        if not successors:
            return None
        waypoint = successors[0]
        if waypoint.is_intersection:
            return waypoint


def get_geometric_linear_intersection_by_loc_and_intersection(ego_actor_loc, other_actor_loc):
    """
    Obtain the intersection point of two actors' paths using their waypoints.

    For each location, the nearest lane waypoint is taken and the lane is
    followed in 2 m steps to its first junction waypoint. The straight line
    through those two waypoints approximates that actor's path, and the two
    lines are intersected in the x/y plane.

    Args:
        ego_actor_loc (carla.Location): location on the ego vehicle's lane.
        other_actor_loc (carla.Location): location on the other actor's lane.

    Returns:
        carla.Location or None: the intersection point (with ``z=0``). Returns
        None if the lines are parallel or degenerate, or if either lane ends
        before reaching a junction.
    """
    carla_map = CarlaDataProvider.get_map()

    # [ego start, ego junction, other start, other junction]
    path_waypoints = []
    for location in (ego_actor_loc, other_actor_loc):
        start_wp = carla_map.get_waypoint(location)
        junction_wp = _next_junction_waypoint(start_wp)
        if junction_wp is None:
            return None
        path_waypoints.extend((start_wp, junction_wp))

    # Homogeneous 2D coordinates: the cross product of two points is the line
    # through them, and the cross product of two lines is their intersection.
    points = np.array([[wp.transform.location.x, wp.transform.location.y, 1.0]
                       for wp in path_waypoints])
    line_ego = np.cross(points[0], points[1])
    line_other = np.cross(points[2], points[3])
    x, y, z = np.cross(line_ego, line_other)
    if z == 0:
        # Parallel (or degenerate) lines have no unique intersection point.
        return None

    return carla.Location(x=x / z, y=y / z, z=0)
